{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![umap in atlas](https://docs.nomic.ai/img/umap-with-nomic-atlas.png)\n",
    "\n",
    "# Visualizing MNIST Training Dynamics with Nomic Atlas\n",
    "\n",
    "UMAP is available as a projection in [Nomic Atlas](https://atlas.nomic.ai), which creates interactive maps of your data with AI analysis, vector search APIs, and additional resources like topic label generation.\n",
    "\n",
    "![mnist embeddings in atlas](https://assets.nomicatlas.com/mnist-training-embeddings-umap-short.gif)\n",
    "\n",
    "Nomic Atlas automatically generates embeddings for your data and allows you to explore large datasets in a web browser. Atlas provides:\n",
    "\n",
    "* In-browser analysis of your UMAP data with the [Atlas Analyst](https://docs.nomic.ai/atlas/data-maps/atlas-analyst)\n",
    "* Vector search over your UMAP data using the [Nomic API](https://docs.nomic.ai/atlas/data-maps/guides/vector-search-over-your-data)\n",
    "* Interactive features like zooming, recoloring, searching, and filtering in the [Nomic Atlas data map](https://docs.nomic.ai/atlas/data-maps/controls)\n",
    "* Scalability for millions of data points\n",
    "* Rich information display on hover\n",
    "* Shareable UMAPs via URL links to your embeddings and data maps in Atlas\n",
    "\n",
    "This example demonstrates how to use [Atlas](https://docs.nomic.ai/atlas/embeddings-and-retrieval/guides/using-umap-with-atlas) to visualize the training dynamics of your neural network using embeddings and UMAP.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "1. Get the python package with `pip instll nomic`\n",
    "\n",
    "2. Get A Nomic API key [here](https://atlas.nomic.ai/cli-login)\n",
    "\n",
    "3. Run `nomic login nk-...` in a terminal window or\n",
    "\n",
    "```python\n",
    "import nomic\n",
    "nomic.login('nk-...')\n",
    "```\n",
    "\n",
    "at the top of your code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We set up some imports, hyperparameters, and a helper function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cpu\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import numpy as np\n",
    "import time\n",
    "from PIL import Image\n",
    "import base64\n",
    "import io\n",
    "\n",
    "NUM_EPOCHS = 20\n",
    "LEARNING_RATE = 3e-6\n",
    "BATCH_SIZE = 128\n",
    "NUM_VIS_SAMPLES = 2000\n",
    "EMBEDDING_DIM = 128\n",
    "ATLAS_DATASET_NAME = \"mnist_training_embeddings\"\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\\n\")\n",
    "\n",
    "def tensor_to_html(tensor):\n",
    "    \"\"\"Helper function to convert image tensors to HTML for rendering in Nomic Atlas\"\"\"\n",
    "    # Denormalize the image\n",
    "    img = torch.clamp(tensor.clone().detach().cpu().squeeze(0) * 0.3081 + 0.1307, 0, 1)\n",
    "    img_pil = Image.fromarray((img.numpy() * 255).astype('uint8'), mode='L')\n",
    "    buffered = io.BytesIO()\n",
    "    img_pil.save(buffered, format=\"PNG\")\n",
    "    img_str = base64.b64encode(buffered.getvalue()).decode()\n",
    "    return f'<img src=\"data:image/png;base64,{img_str}\" width=\"28\" height=\"28\">'"
   ]
  },
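  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `tensor_to_html` helper undoes the MNIST normalization before encoding each image as a PNG. As a minimal sketch of that arithmetic (using NumPy on a made-up pixel array rather than a real MNIST image, with the same mean 0.1307 and standard deviation 0.3081 used in this notebook), normalizing and then denormalizing should return the original pixel values:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "MEAN, STD = 0.1307, 0.3081  # MNIST statistics used in this notebook\n",
    "\n",
    "# Hypothetical 2x2 image with pixel values in [0, 1], for illustration only\n",
    "pixels = np.array([[0.0, 0.25], [0.5, 1.0]])\n",
    "\n",
    "normalized = (pixels - MEAN) / STD  # what transforms.Normalize computes\n",
    "restored = np.clip(normalized * STD + MEAN, 0.0, 1.0)  # the step tensor_to_html applies\n",
    "\n",
    "# Scale to 8-bit grayscale, as the helper does before building the PIL image\n",
    "as_uint8 = (restored * 255).astype('uint8')\n",
    "print(np.allclose(restored, pixels))  # True\n",
    "```"
   ]
  },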
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We setup a CNN image classifier for MNIST data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training on 60000 samples, visualizing 2000 test samples per epoch.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class MNIST_CNN(nn.Module):\n",
    "    def __init__(self, embedding_dim=128):\n",
    "        super(MNIST_CNN, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc1 = nn.Linear(64 * 7 * 7, embedding_dim)\n",
    "        self.relu3 = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(embedding_dim, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool1(self.relu1(self.conv1(x)))\n",
    "        x = self.pool2(self.relu2(self.conv2(x)))\n",
    "        x = self.flatten(x)\n",
    "        embeddings = self.relu3(self.fc1(x))\n",
    "        output = self.fc2(embeddings)\n",
    "        return output, embeddings\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "])\n",
    "\n",
    "train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "persistent_workers_flag = True if device.type not in ['mps', 'cpu'] else False\n",
    "num_workers_val = 2 if persistent_workers_flag else 0\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers_val, persistent_workers=persistent_workers_flag if num_workers_val > 0 else False)\n",
    "vis_indices = list(range(NUM_VIS_SAMPLES))\n",
    "vis_subset = Subset(test_dataset, vis_indices)\n",
    "test_loader_for_vis = DataLoader(vis_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_val, persistent_workers=persistent_workers_flag if num_workers_val > 0 else False)\n",
    "print(f\"Training on {len(train_dataset)} samples, visualizing {NUM_VIS_SAMPLES} test samples per epoch.\\n\")"
   ]
  },
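  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `64 * 7 * 7` input size of `fc1` follows from the layer shapes: each 3x3 convolution with `padding=1` keeps the spatial size, and each 2x2 max-pool with stride 2 halves it. A minimal sketch of that calculation, using the standard output-size formula for convolution and pooling layers without dilation (the `out_size` helper is hypothetical, written only for this illustration):\n",
    "\n",
    "```python\n",
    "def out_size(size, kernel, stride=1, padding=0):\n",
    "    # Spatial output size of a conv or pooling layer (no dilation)\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "size = 28                                    # MNIST images are 28x28\n",
    "size = out_size(size, kernel=3, padding=1)   # conv1 -> 28\n",
    "size = out_size(size, kernel=2, stride=2)    # pool1 -> 14\n",
    "size = out_size(size, kernel=3, padding=1)   # conv2 -> 14\n",
    "size = out_size(size, kernel=2, stride=2)    # pool2 -> 7\n",
    "\n",
    "flattened = 64 * size * size                 # conv2 produces 64 channels\n",
    "print(flattened)  # 3136\n",
    "```"
   ]
  },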
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect Embeddings During Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We save embeddings from the last layer at each iteration to track the change in the model's output distribution over the course of training. This is what Atlas is uniquely well suited to visualize."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Batch [200/469], Avg Loss: 2.2656\n",
      "Epoch [1/20], Batch [400/469], Avg Loss: 2.1646\n",
      "Epoch 1/20 training finished in 15.08s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 1.\n",
      "\n",
      "Epoch [2/20], Batch [200/469], Avg Loss: 1.9691\n",
      "Epoch [2/20], Batch [400/469], Avg Loss: 1.7807\n",
      "Epoch 2/20 training finished in 14.78s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 2.\n",
      "\n",
      "Epoch [3/20], Batch [200/469], Avg Loss: 1.5193\n",
      "Epoch [3/20], Batch [400/469], Avg Loss: 1.3360\n",
      "Epoch 3/20 training finished in 14.75s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 3.\n",
      "\n",
      "Epoch [4/20], Batch [200/469], Avg Loss: 1.1200\n",
      "Epoch [4/20], Batch [400/469], Avg Loss: 0.9892\n",
      "Epoch 4/20 training finished in 14.69s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 4.\n",
      "\n",
      "Epoch [5/20], Batch [200/469], Avg Loss: 0.8479\n",
      "Epoch [5/20], Batch [400/469], Avg Loss: 0.7668\n",
      "Epoch 5/20 training finished in 14.80s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 5.\n",
      "\n",
      "Epoch [6/20], Batch [200/469], Avg Loss: 0.6828\n",
      "Epoch [6/20], Batch [400/469], Avg Loss: 0.6265\n",
      "Epoch 6/20 training finished in 14.86s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 6.\n",
      "\n",
      "Epoch [7/20], Batch [200/469], Avg Loss: 0.5677\n",
      "Epoch [7/20], Batch [400/469], Avg Loss: 0.5345\n",
      "Epoch 7/20 training finished in 14.76s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 7.\n",
      "\n",
      "Epoch [8/20], Batch [200/469], Avg Loss: 0.4923\n",
      "Epoch [8/20], Batch [400/469], Avg Loss: 0.4709\n",
      "Epoch 8/20 training finished in 14.68s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 8.\n",
      "\n",
      "Epoch [9/20], Batch [200/469], Avg Loss: 0.4373\n",
      "Epoch [9/20], Batch [400/469], Avg Loss: 0.4255\n",
      "Epoch 9/20 training finished in 15.03s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 9.\n",
      "\n",
      "Epoch [10/20], Batch [200/469], Avg Loss: 0.3979\n",
      "Epoch [10/20], Batch [400/469], Avg Loss: 0.3861\n",
      "Epoch 10/20 training finished in 14.76s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 10.\n",
      "\n",
      "Epoch [11/20], Batch [200/469], Avg Loss: 0.3686\n",
      "Epoch [11/20], Batch [400/469], Avg Loss: 0.3590\n",
      "Epoch 11/20 training finished in 14.97s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 11.\n",
      "\n",
      "Epoch [12/20], Batch [200/469], Avg Loss: 0.3402\n",
      "Epoch [12/20], Batch [400/469], Avg Loss: 0.3371\n",
      "Epoch 12/20 training finished in 14.83s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 12.\n",
      "\n",
      "Epoch [13/20], Batch [200/469], Avg Loss: 0.3214\n",
      "Epoch [13/20], Batch [400/469], Avg Loss: 0.3154\n",
      "Epoch 13/20 training finished in 14.89s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 13.\n",
      "\n",
      "Epoch [14/20], Batch [200/469], Avg Loss: 0.3024\n",
      "Epoch [14/20], Batch [400/469], Avg Loss: 0.3010\n",
      "Epoch 14/20 training finished in 14.86s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 14.\n",
      "\n",
      "Epoch [15/20], Batch [200/469], Avg Loss: 0.2897\n",
      "Epoch [15/20], Batch [400/469], Avg Loss: 0.2917\n",
      "Epoch 15/20 training finished in 14.23s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 15.\n",
      "\n",
      "Epoch [16/20], Batch [200/469], Avg Loss: 0.2734\n",
      "Epoch [16/20], Batch [400/469], Avg Loss: 0.2721\n",
      "Epoch 16/20 training finished in 14.19s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 16.\n",
      "\n",
      "Epoch [17/20], Batch [200/469], Avg Loss: 0.2644\n",
      "Epoch [17/20], Batch [400/469], Avg Loss: 0.2583\n",
      "Epoch 17/20 training finished in 14.23s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 17.\n",
      "\n",
      "Epoch [18/20], Batch [200/469], Avg Loss: 0.2535\n",
      "Epoch [18/20], Batch [400/469], Avg Loss: 0.2503\n",
      "Epoch 18/20 training finished in 15.24s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 18.\n",
      "\n",
      "Epoch [19/20], Batch [200/469], Avg Loss: 0.2427\n",
      "Epoch [19/20], Batch [400/469], Avg Loss: 0.2391\n",
      "Epoch 19/20 training finished in 14.83s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 19.\n",
      "\n",
      "Epoch [20/20], Batch [200/469], Avg Loss: 0.2373\n",
      "Epoch [20/20], Batch [400/469], Avg Loss: 0.2326\n",
      "Epoch 20/20 training finished in 15.11s.\n",
      "\n",
      "Collected 2000 embeddings for visualization in epoch 20.\n",
      "\n",
      "Total training and embedding extraction time: 303.11s\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model = MNIST_CNN(embedding_dim=EMBEDDING_DIM).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "all_embeddings_list = []\n",
    "all_metadata_list = []\n",
    "all_images_html = []\n",
    "overall_start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    epoch_start_time = time.time()\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs, _ = model(data)\n",
    "        loss = criterion(outputs, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        if (batch_idx + 1) % 200 == 0:\n",
    "            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx+1}/{len(train_loader)}], Avg Loss: {running_loss / 200:.4f}')\n",
    "            running_loss = 0.0\n",
    "    print(f\"Epoch {epoch+1}/{NUM_EPOCHS} training finished in {time.time() - epoch_start_time:.2f}s.\\n\")\n",
    "    model.eval()\n",
    "    vis_samples_collected_this_epoch = 0\n",
    "    image_offset_in_vis_subset = 0 \n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader_for_vis:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            _, embeddings_batch = model(data)\n",
    "            for i in range(embeddings_batch.size(0)):\n",
    "                original_idx_in_subset = image_offset_in_vis_subset + i \n",
    "                if original_idx_in_subset >= NUM_VIS_SAMPLES:\n",
    "                    continue\n",
    "                all_embeddings_list.append(embeddings_batch[i].cpu().numpy())                \n",
    "                img_html = tensor_to_html(data[i])\n",
    "                all_images_html.append(img_html)\n",
    "                all_metadata_list.append({\n",
    "                    'id': f'vis_img_{original_idx_in_subset}_epoch_{epoch}',\n",
    "                    'epoch': epoch,\n",
    "                    'label': f'Digit: {target[i].item()}',\n",
    "                    'vis_sample_idx': original_idx_in_subset,\n",
    "                    'image_html': img_html\n",
    "                })\n",
    "                vis_samples_collected_this_epoch += 1\n",
    "            image_offset_in_vis_subset += embeddings_batch.size(0)\n",
    "            if vis_samples_collected_this_epoch >= NUM_VIS_SAMPLES: \n",
    "                break\n",
    "    print(f\"Collected {vis_samples_collected_this_epoch} embeddings for visualization in epoch {epoch+1}.\\n\")\n",
    "total_script_time = time.time() - overall_start_time\n",
    "print(f\"Total training and embedding extraction time: {total_script_time:.2f}s\\n\")\n"
   ]
  },
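  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the loop, `all_embeddings_list` holds one vector per visualized sample per epoch, so stacking it gives an array of shape `(NUM_EPOCHS * NUM_VIS_SAMPLES, EMBEDDING_DIM)`, and each `id` combines the sample index and the epoch so that it does not repeat. A minimal sketch of that bookkeeping, using small made-up sizes and random vectors in place of real embeddings:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Small stand-in sizes for illustration; the notebook uses 20, 2000, and 128\n",
    "num_epochs, num_vis_samples, embedding_dim = 3, 5, 4\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "embeddings, metadata = [], []\n",
    "for epoch in range(num_epochs):\n",
    "    for idx in range(num_vis_samples):\n",
    "        embeddings.append(rng.random(embedding_dim))  # placeholder for a real embedding\n",
    "        metadata.append({'id': f'vis_img_{idx}_epoch_{epoch}', 'epoch': epoch, 'vis_sample_idx': idx})\n",
    "\n",
    "stacked = np.array(embeddings)\n",
    "print(stacked.shape)  # (15, 4)\n",
    "\n",
    "# Every row has exactly one metadata record, and ids do not repeat across epochs\n",
    "ids = [m['id'] for m in metadata]\n",
    "print(len(ids) == len(set(ids)) == stacked.shape[0])  # True\n",
    "```"
   ]
  },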
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Atlas Dataset and Upload Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-11 16:50:19.819\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mnomic.dataset\u001b[0m:\u001b[36m_create_project\u001b[0m:\u001b[36m867\u001b[0m - \u001b[1mOrganization name: `nomic`\u001b[0m\n",
      "\u001b[32m2025-05-11 16:50:20.486\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mnomic.dataset\u001b[0m:\u001b[36m_create_project\u001b[0m:\u001b[36m895\u001b[0m - \u001b[1mCreating dataset `mnist-training-embeddings`\u001b[0m\n",
      "100%|██████████| 8/8 [00:29<00:00,  3.71s/it]\n",
      "\u001b[32m2025-05-11 16:50:50.784\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mnomic.dataset\u001b[0m:\u001b[36m_add_data\u001b[0m:\u001b[36m1702\u001b[0m - \u001b[1mUpload succeeded.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from nomic import AtlasDataset\n",
    "\n",
    "dataset = AtlasDataset(\"mnist-training-embeddings\")\n",
    "dataset.add_data(data=all_metadata_list, embeddings=np.array(all_embeddings_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Data Map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-11 16:50:52.419\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mnomic.dataset\u001b[0m:\u001b[36mcreate_index\u001b[0m:\u001b[36m1289\u001b[0m - \u001b[1mCreated map `0196c11d-6611-46f0-38ab-b0e9778ad1fb` in dataset `nomic/mnist-training-embeddings`: https://atlas.nomic.ai/data/nomic/mnist-training-embeddings\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Atlas Projection 0196c11d-6611-46f0-38ab-b0e9778ad1fb. Status Building Map. <a target=\"_blank\" href=\"https://atlas.nomic.ai/data/nomic/mnist-training-embeddings/map\">view online</a>"
      ],
      "text/plain": [
       "0196c11d-6611-46f0-38ab-b0e9778ad1fb: https://atlas.nomic.ai/data/nomic/mnist-training-embeddings"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.create_index(projection='umap', topic_model=False) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Your map will be available in your [Atlas Dashboard](https://atlas.nomic.ai/data)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "atlas",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
